import subprocess as sp
import time
from os.path import join

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm


class Trainer:
    """Supervised training loop with logit adjustment, evaluation and checkpointing.

    During training, a per-class log-prior computed from the training labels
    (see ``compute_adjustment``) is added to the model outputs before the loss
    is computed. Evaluation uses the raw model outputs.

    ``config`` is a nested dict. Keys read by this class:
      - ``config["train"]``: ``batch_size``, ``num_workers``, ``epochs``,
        ``print_every``, ``save_every_epoch``
      - ``config["general"]``: ``save_model_dir`` and ``logger["frequency"]``

    ``logger`` must provide ``info(message, metrics_dict)``; presumably a
    project-specific logger (TODO confirm).
    """

    def __init__(
        self,
        config,
        model,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        scheduler=None,
        val_set=None,
    ):
        self.model = model
        self.logger = logger
        self.train_set = train_set
        self.test_set = test_set
        # A falsy val_set (None, or an empty sized dataset) disables validation.
        self.val_set = val_set if val_set else None
        self.config = config
        self.cur_iter = 0
        self.batch_size = self.config["train"]["batch_size"]
        self.num_workers = self.config["train"]["num_workers"]
        self.criterion = criterion
        self.optimizer = optimizer
        self.global_iter = 0
        # Total number of training epochs.
        self.epoch = self.config["train"]["epochs"]
        self.log_interval = self.config["train"]["print_every"]
        self.save_every_epoch = config["train"]["save_every_epoch"]
        # TODO: identify the GPU device in the config
        # TODO: Add evaluate every epoch
        n_gpus = torch.cuda.device_count()
        if n_gpus == 1:
            self.model = self.model.cuda()
            self.device = torch.device("cuda:0")
        elif n_gpus > 1:
            self.model = DataParallel(self.model).cuda()
            self.device = torch.device("cuda")
        else:
            # Without this branch self.device was never set and every call
            # to prepare_data failed with AttributeError on CPU-only hosts.
            self.device = torch.device("cpu")

        self.scheduler = scheduler
        self.get_dataloaders()

    def get_gpu_memory(self):
        """Return the free memory of each visible GPU as reported by nvidia-smi.

        Returns:
            list[int]: one entry per GPU, the leading number of each
            ``memory.free`` CSV row (in whatever unit nvidia-smi prints).
        """
        command = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv"]
        output = sp.check_output(command).decode("ascii")
        # Drop the CSV header and the empty string after the final newline.
        rows = output.split("\n")[:-1][1:]
        return [int(row.split()[0]) for row in rows]

    def save_model(self, file_name):
        """Save the model's state_dict into ``config["general"]["save_model_dir"]``.

        ``.pth`` is appended unless ``file_name`` already ends in ``.pth`` or
        ``.pt``. Note: when the model is wrapped in DataParallel, the saved
        keys carry the wrapper's ``module.`` prefix.
        """
        # The old test (`".pth" not in name or ".pt" not in name`) was true
        # for "x.pt" and produced "x.pt.pth"; compare the suffix instead.
        if not file_name.endswith((".pth", ".pt")):
            file_name += ".pth"
        torch.save(
            self.model.state_dict(),
            join(self.config["general"]["save_model_dir"], file_name),
        )
        print(f"Model saved as {file_name}")

    def save_best_model(self):
        """Save the current weights as ``best.pth``."""
        self.save_model("best")

    def save_last_model(self):
        """Save the current weights as ``last.pth``."""
        self.save_model("last")

    def accuracy(self, logit, target, topk=(1,)):
        """Compute top-k accuracy (in percent) for each k in ``topk``.

        Args:
            logit: scores indexed as (batch, classes).
            target: class indices, one per sample.
            topk: the k values to evaluate.

        Returns:
            list of 1-element tensors, the percentage of samples whose target
            is among the k highest-scoring classes.
        """
        output = F.softmax(logit, dim=1)
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def get_lr(self):
        """Return the learning rate of the optimizer's first param group (None if none)."""
        for param_group in self.optimizer.param_groups:
            return param_group["lr"]
        return None

    def prepare_data(self, data):
        """Move a batch to ``self.device`` and cast it to training dtypes.

        Args:
            data: an indexable batch, either ``(inputs, labels)`` or
                ``(inputs, labels, attributes)``.

        Returns:
            tuple: ``(inputs, labels, attributes)``. ``inputs`` is float,
            ``labels`` and ``attributes`` are long. ``attributes`` is None for
            2-element batches.
        """
        inputs = data[0].to(self.device, non_blocking=True).float()
        labels = data[1].to(self.device, non_blocking=True).long()
        try:
            attributes = data[2].to(self.device, non_blocking=True).long()
        except IndexError:
            attributes = None
        # Drop singleton label dims, e.g. (B, 1) -> (B,), but never the batch
        # dim: a bare torch.squeeze turned a (1, 1) batch into a 0-d tensor.
        for dim in range(labels.dim() - 1, 0, -1):
            if labels.size(dim) == 1:
                labels = labels.squeeze(dim)
        return inputs, labels, attributes

    def compute_adjustment(self, tro=1):
        """Compute per-class logit adjustments from training label frequencies.

        Returns ``log(p_c ** tro + 1e-12)`` for each label c that occurs in
        ``self.train_loader``, ordered by ascending label. Here p_c is the
        empirical frequency of c. The result is a float32 tensor on
        ``self.device``.

        Args:
            tro: exponent applied to the class priors.

        NOTE(review): only labels that actually occur get an entry, so the
        vector lines up with the model outputs only if every class is present
        in the training set. Confirm this for the datasets in use.
        """
        label_freq = {}
        for data in self.train_loader:
            # Batches may be (inputs, labels) or (inputs, labels, attributes).
            # Counting on the CPU avoids one device sync per sample.
            for value in data[1].reshape(-1).tolist():
                key = int(value)
                label_freq[key] = label_freq.get(key, 0) + 1
        counts = np.array([label_freq[key] for key in sorted(label_freq)])
        priors = counts / counts.sum()
        adjustments = np.log(priors**tro + 1e-12)
        # float32 so that adding this to float32 model outputs does not
        # promote them (and the loss) to float64.
        return torch.from_numpy(adjustments).float().to(self.device)

    def run(self):
        """Train for ``self.epoch`` epochs, evaluating and checkpointing each epoch.

        Each epoch:
          1. Train on ``train_loader`` using logit-adjusted outputs.
          2. Evaluate on the validation set (if any), then on the test set.
          3. Save "best" when test accuracy improves.
          4. Step the scheduler and save "last".
          5. Every ``save_every_epoch`` epochs, also save "<epoch>".
        """
        print("==> Start training..")
        best_acc = 0.0
        print("start to compute adj")
        logit_adjustments = self.compute_adjustment()
        print("finish compute adj")
        log_frequency = self.config["general"]["logger"]["frequency"]
        for cur_epoch in range(self.epoch):
            self.model.train()
            epoch_loss, epoch_correct, total_num = 0.0, 0.0, 0.0
            with tqdm(self.train_loader, unit="batch") as tepoch:
                for data in tepoch:
                    tepoch.set_description(f"Epoch {cur_epoch}")
                    inputs, labels, attributes = self.prepare_data(data)
                    self.optimizer.zero_grad()
                    outputs = self.model(inputs)
                    outputs = outputs + logit_adjustments
                    loss = self.criterion(outputs, labels)
                    loss.backward()
                    self.optimizer.step()

                    loss_value = loss.item()
                    batch_size = inputs.size(0)
                    correct = (outputs.argmax(1) == labels).sum().item()
                    batch_acc = 100.0 * correct / batch_size
                    tepoch.set_postfix(
                        loss=loss_value,
                        accuracy=batch_acc,
                        lr=self.get_lr(),
                    )
                    # Accumulate a Python float. Adding the loss tensor kept
                    # every batch's autograd graph alive for the whole epoch.
                    epoch_loss += loss_value
                    epoch_correct += correct
                    total_num += batch_size
                    self.global_iter += 1
                    if self.global_iter % log_frequency == 0:
                        self.logger.info(
                            f"[{cur_epoch}]/[{self.epoch}], Global Iter: {self.global_iter}, Loss: {loss_value:.4f}, Acc: {batch_acc:.4f}, lr: {self.get_lr():.6f}",
                            {
                                "cur_epoch": cur_epoch,
                                "iter": self.global_iter,
                                "loss": loss_value,
                                "Accuracy": batch_acc,
                                "lr": self.get_lr(),
                            },
                        )
            # NOTE(review): this divides a sum of per-batch criterion values
            # by the number of samples, not batches. With a mean-reduced
            # criterion it is about 1/batch_size of the mean loss. Confirm
            # the intended normalisation.
            epoch_loss /= total_num
            epoch_acc = epoch_correct / total_num * 100.0
            if self.val_set:
                _ = self.evaluate(val=True)
            test_acc = self.evaluate(val=False)

            if test_acc > best_acc:
                best_acc = test_acc
                self.save_best_model()
            print(
                f"Epoch: {cur_epoch}, Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}"
            )
            self.logger.info(
                f"[{cur_epoch}]/[{self.epoch}], Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}",
                {
                    "test_epoch": cur_epoch,
                    "loss": epoch_loss,
                    "Train Acc": epoch_acc,
                    "Test Acc": test_acc,
                    "Best Test Acc": best_acc,
                },
            )

            if self.scheduler:
                self.scheduler.step()
            self.save_last_model()

            # A falsy save_every_epoch (e.g. 0) disables periodic snapshots
            # instead of raising ZeroDivisionError.
            if self.save_every_epoch and cur_epoch % self.save_every_epoch == 0:
                self.save_model(f"{cur_epoch}")

    def evaluate(self, val=True, second_model=False):
        """Compute top-1 accuracy (in percent) on the validation or test loader.

        Args:
            val: use ``self.val_loader`` if True, otherwise ``self.test_loader``.
            second_model: evaluate ``self.model2`` instead of ``self.model``.
                Falls back to ``self.model`` (with a message) if there is no
                ``model2``.

        Returns:
            float: accuracy in percent over every evaluated sample.
        """
        model_test = self.model
        if second_model:
            model_test = getattr(self, "model2", None)
            if model_test is None:
                print("There is no second model. Still testing the first model.")
                model_test = self.model
        model_test.eval()
        loader = self.val_loader if val else self.test_loader
        evaluate_type = "Val" if val else "Test"
        correct, total_num = 0.0, 0.0
        with torch.no_grad():
            for data in loader:
                inputs, labels, _ = self.prepare_data(data)
                outputs = model_test(inputs)
                correct += (outputs.argmax(1) == labels).sum().item()
                total_num += labels.size(0)
        acc = correct / total_num * 100
        print(f"{evaluate_type} Acc: {acc:.4f}")
        return acc

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the trainer's batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=torch.cuda.is_available(),
        )

    def get_dataloaders(self):
        """Create train/test loaders, and a val loader when a validation set exists.

        Only the training loader shuffles. Evaluation order does not affect
        accuracy, so the test and val loaders iterate deterministically.
        """
        self.train_loader = self._make_loader(self.train_set, shuffle=True)
        self.test_loader = self._make_loader(self.test_set, shuffle=False)
        if self.val_set:
            self.val_loader = self._make_loader(self.val_set, shuffle=False)
        else:
            self.val_loader = None
